{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stable Model Training with monitoring through Weights & Biases"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### NOTES:  \n",
    "* This is \"NoGAN\" based training, described in the DeOldify readme.\n",
    "* This model prioritizes stable and reliable renderings.  It does particularly well on portraits and landscapes.  It's not as colorful as the artistic model.\n",
    "* Training with this notebook has been logged and monitored through [Weights & Biases](https://www.wandb.com/). Refer to [W&B Report](https://app.wandb.ai/borisd13/DeOldify/reports?view=borisd13%2FDeOldify).\n",
    "* It is **highly** recommended to use a 11 Go GPU to run this notebook. Anything lower will require to reduce the batch size (leading to moro instability) or use of \"Large Model Support\" from IBM WML-CE (not so easy to setup). An alternative is to rent ressources online."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install W&B Callback\n",
    "#!pip install wandb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#NOTE:  This must be the first call in order to work properly!\n",
    "from deoldify import device\n",
    "from deoldify.device_id import DeviceId\n",
    "#choices:  CPU, GPU0...GPU7\n",
    "device.set(device=DeviceId.GPU0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import fastai\n",
    "from fastai import *\n",
    "from fastai.vision import *\n",
    "from fastai.vision.gan import *\n",
    "from deoldify.generators import *\n",
    "from deoldify.critics import *\n",
    "from deoldify.dataset import *\n",
    "from deoldify.loss import *\n",
    "from deoldify.save import *\n",
    "from PIL import Image, ImageDraw, ImageFont\n",
    "from PIL import ImageFile\n",
    "from torch.utils.data.sampler import RandomSampler, SequentialSampler\n",
    "from tqdm import tqdm\n",
    "import wandb\n",
    "from wandb.fastai import WandbCallback"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up W&B: checks user can connect to W&B servers\n",
    "# Note: set up API key the first time\n",
    "wandb.login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset can be downloaded from https://www.kaggle.com/c/imagenet-object-localization-challenge/data\n",
    "path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
    "path_hr = path\n",
    "path_lr = path/'bandw'\n",
    "\n",
    "proj_id = 'StableModel'\n",
    "\n",
    "gen_name = proj_id + '_gen'\n",
    "pre_gen_name = gen_name + '_0'\n",
    "crit_name = proj_id + '_crit'\n",
    "\n",
    "name_gen = proj_id + '_image_gen'\n",
    "path_gen = path/name_gen\n",
    "\n",
    "nf_factor = 2\n",
    "pct_start = 1e-8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Iterating through the dataset\n",
    "\n",
    "The dataset is very large and it would take a long time to iterate through all the samples at each epoch.\n",
    "\n",
    "We use custom samplers in order to limit epochs to subsets of data while still iterating slowly through the entire dataset (epoch after epoch). This let us run the validation loop more often where we log metrics as well as prediction samples on validation data."
   ]
  },
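  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea described above can be sketched without PyTorch or fastai. This is a simplified stand-in for illustration, not the sampler used in this notebook: `FixedLenEpochs` and `next_epoch` are hypothetical names, and only the standard library is assumed. Each call returns one epoch's worth of indices, and successive calls walk through the whole dataset before reshuffling.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "class FixedLenEpochs:\n",
    "    'Hypothetical helper: serves dataset indices in fixed-size chunks, one chunk per epoch.'\n",
    "    def __init__(self, n_items, epoch_size, seed=0):\n",
    "        self.n_items, self.epoch_size = n_items, epoch_size\n",
    "        self.rng = random.Random(seed)\n",
    "        self.pending = []  # indices not yet served in the current pass\n",
    "\n",
    "    def next_epoch(self):\n",
    "        idx = []\n",
    "        while len(idx) < self.epoch_size:\n",
    "            if not self.pending:  # pass finished: reshuffle and start a new one\n",
    "                self.pending = list(range(self.n_items))\n",
    "                self.rng.shuffle(self.pending)\n",
    "            take = self.epoch_size - len(idx)\n",
    "            idx += self.pending[:take]\n",
    "            self.pending = self.pending[take:]\n",
    "        return idx\n",
    "\n",
    "epochs = FixedLenEpochs(n_items=10, epoch_size=4)\n",
    "first, second = epochs.next_epoch(), epochs.next_epoch()\n",
    "# Two consecutive epochs within one pass never share a sample\n",
    "assert len(first) == 4 and not set(first) & set(second)\n",
    "```\n",
    "\n",
    "Short epochs trade a little bookkeeping for much more frequent validation, which is what makes the logged metrics and prediction samples useful on a dataset of this size."
   ]
  },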
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce quantity of samples per training epoch\n",
    "# Adapted from https://forums.fast.ai/t/epochs-of-arbitrary-length/27777/10\n",
    "\n",
    "@classmethod\n",
    "def create(cls, train_ds:Dataset, valid_ds:Dataset, test_ds:Optional[Dataset]=None, path:PathOrStr='.', bs:int=64,\n",
    "            val_bs:int=None, num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None,\n",
    "            device:torch.device=None, collate_fn:Callable=data_collate, no_check:bool=False, sampler=None, **dl_kwargs)->'DataBunch':\n",
    "    \"Create a `DataBunch` from `train_ds`, `valid_ds` and maybe `test_ds` with a batch size of `bs`. Passes `**dl_kwargs` to `DataLoader()`\"\n",
    "    datasets = cls._init_ds(train_ds, valid_ds, test_ds)\n",
    "    val_bs = ifnone(val_bs, bs)\n",
    "    if sampler is None: sampler = [RandomSampler] + 3*[SequentialSampler]\n",
    "    dls = [DataLoader(d, b, sampler=sa(d), drop_last=sh, num_workers=num_workers, **dl_kwargs) for d,b,sh,sa in\n",
    "            zip(datasets, (bs,val_bs,val_bs,val_bs), (True,False,False,False), sampler) if d is not None]\n",
    "    return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)\n",
    "\n",
    "ImageDataBunch.create = create\n",
    "ImageImageList._bunch = ImageDataBunch\n",
    "\n",
    "class FixedLenRandomSampler(RandomSampler):\n",
    "    def __init__(self, data_source, epoch_size):\n",
    "        super().__init__(data_source)\n",
    "        self.epoch_size = epoch_size\n",
    "        self.not_sampled = np.array([True]*len(data_source))\n",
    "    \n",
    "    @property\n",
    "    def reset_state(self): self.not_sampled[:] = True\n",
    "        \n",
    "    def __iter__(self):\n",
    "        ns = sum(self.not_sampled)\n",
    "        idx_last = []\n",
    "        if ns >= len(self):\n",
    "            idx = np.random.choice(np.where(self.not_sampled)[0], size=len(self), replace=False).tolist()\n",
    "            if ns == len(self): self.reset_state\n",
    "        else:\n",
    "            idx_last = np.where(self.not_sampled)[0].tolist()\n",
    "            self.reset_state\n",
    "            idx = np.random.choice(np.where(self.not_sampled)[0], size=len(self)-len(idx_last), replace=False).tolist()\n",
    "        self.not_sampled[idx] = False\n",
    "        idx = [*idx_last, *idx]\n",
    "        return iter(idx)\n",
    "    \n",
    "    def __len__(self):\n",
    "        return self.epoch_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(bs:int, sz:int, keep_pct=1.0, random_seed=None, valid_pct=0.2, epoch_size=1000):\n",
    "    \n",
    "    # Create samplers\n",
    "    train_sampler = partial(FixedLenRandomSampler, epoch_size=epoch_size)\n",
    "    samplers = [train_sampler, SequentialSampler, SequentialSampler, SequentialSampler]\n",
    "\n",
    "    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, random_seed=random_seed,\n",
    "                             keep_pct=keep_pct, samplers=samplers, valid_pct=valid_pct)\n",
    "\n",
    "# Function modified to allow use of custom samplers\n",
    "def get_colorize_data(sz:int, bs:int, crappy_path:Path, good_path:Path, random_seed:int=None,\n",
    "        keep_pct:float=1.0, num_workers:int=8, samplers=None, valid_pct=0.2, xtra_tfms=[])->ImageDataBunch:\n",
    "    src = (ImageImageList.from_folder(crappy_path, convert_mode='RGB')\n",
    "        .use_partial_data(sample_pct=keep_pct, seed=random_seed)\n",
    "        .split_by_rand_pct(valid_pct, seed=random_seed))\n",
    "    data = (src.label_from_func(lambda x: good_path/x.relative_to(crappy_path))\n",
    "        .transform(get_transforms(max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms), size=sz, tfm_y=True)\n",
    "        .databunch(bs=bs, num_workers=num_workers, sampler=samplers, no_check=True)\n",
    "        .normalize(imagenet_stats, do_y=True))\n",
    "    data.c = 3\n",
    "    return data\n",
    "\n",
    "# Function to limit amount of data in critic\n",
    "def filter_data(pct=1.0):\n",
    "    def _f(fname):\n",
    "        if 'test' in str(fname):\n",
    "            if np.random.random_sample() > pct:\n",
    "                return False\n",
    "        return True\n",
    "    return _f\n",
    "\n",
    "def get_crit_data(classes, bs, sz, pct=1.0):\n",
    "    src = ImageList.from_folder(path, include=classes, recurse=True).filter_by_func(filter_data(pct)).split_by_rand_pct(0.1)\n",
    "    ll = src.label_from_folder(classes=classes)\n",
    "    data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\n",
    "           .databunch(bs=bs).normalize(imagenet_stats))\n",
    "    return data\n",
    "\n",
    "def create_training_images(fn,i):\n",
    "    dest = path_lr/fn.relative_to(path_hr)\n",
    "    dest.parent.mkdir(parents=True, exist_ok=True)\n",
    "    img = PIL.Image.open(fn).convert('LA').convert('RGB')\n",
    "    img.save(dest)  \n",
    "    \n",
    "def save_preds(dl):\n",
    "    i=0\n",
    "    names = dl.dataset.items    \n",
    "    for b in tqdm(dl):\n",
    "        preds = learn_gen.pred_batch(batch=b, reconstruct=True)\n",
    "        for o in preds:\n",
    "            o.save(path_gen/names[i].name)\n",
    "            i += 1\n",
    "    \n",
    "def save_gen_images(keep_pct):\n",
    "    if path_gen.exists(): shutil.rmtree(path_gen)\n",
    "    path_gen.mkdir(exist_ok=True)\n",
    "    data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)\n",
    "    save_preds(data_gen.fix_dl)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create black and white training images"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Only runs if the directory isn't already created."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not path_lr.exists():\n",
    "    il = ImageList.from_folder(path_hr)\n",
    "    parallel(create_training_images, il.items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of black & white images\n",
    "data_size = len(list(path_lr.rglob('*.*')))\n",
    "print('Number of black & white images:', data_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pre-train generator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### NOTE\n",
    "Most of the training takes place here in pretraining for NoGAN.  The goal here is to take the generator as far as possible with conventional training, as that is much easier to control and obtain glitch-free results compared to GAN training."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 64px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Init logging of a new run\n",
    "wandb.init(tags=['Pre-train Gen'])  # tags are optional"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=88\n",
    "sz=64\n",
    "\n",
    "# Define target number of training/validation samples as well as number of epochs\n",
    "epoch_train_size = 100 * bs\n",
    "epoch_valid_size = 10 * bs\n",
    "valid_pct = epoch_valid_size / data_size\n",
    "number_epochs = (data_size - epoch_valid_size) // epoch_train_size\n",
    "\n",
    "# Log hyper parameters\n",
    "wandb.config.update({\"Step 1 - batch size\": bs, \"Step 1 - image size\": sz,\n",
    "                     \"Step 1 - epoch size\": epoch_train_size, \"Step 1 - number epochs\": number_epochs})"
   ]
  },
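  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example of the arithmetic in the cell above, the sketch below uses a made-up, round `data_size` of 1,000,000 (an assumption for illustration, not the real image count). `epoch_plan` is a hypothetical helper that mirrors the same formulas. The number of epochs is chosen so that all the short epochs together amount to roughly one pass over the training split.\n",
    "\n",
    "```python\n",
    "def epoch_plan(data_size, bs, train_batches=100, valid_batches=10):\n",
    "    'Hypothetical helper mirroring the epoch arithmetic used in this notebook.'\n",
    "    epoch_train_size = train_batches * bs   # samples seen per training epoch\n",
    "    epoch_valid_size = valid_batches * bs   # samples held out for validation\n",
    "    valid_pct = epoch_valid_size / data_size\n",
    "    number_epochs = (data_size - epoch_valid_size) // epoch_train_size\n",
    "    return epoch_train_size, epoch_valid_size, valid_pct, number_epochs\n",
    "\n",
    "# Assumed dataset size, for illustration only\n",
    "train_size, valid_size, valid_pct, n_epochs = epoch_plan(data_size=1_000_000, bs=88)\n",
    "assert (train_size, valid_size, n_epochs) == (8800, 880, 113)\n",
    "```\n",
    "\n",
    "With these assumed numbers, 113 epochs of 8,800 samples cover 994,400 of the 999,120 training images, so validation runs 113 times during what is effectively a single pass."
   ]
  },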
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_gen = get_data(bs=bs, sz=sz, random_seed=123, valid_pct=valid_pct, epoch_size=100*bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.callback_fns.append(partial(WandbCallback,\n",
    "                                      input_type='images'))  # log prediction samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.fit_one_cycle(number_epochs, pct_start=0.8, max_lr=slice(1e-3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.save(pre_gen_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.unfreeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.fit_one_cycle(number_epochs, pct_start=pct_start,  max_lr=slice(3e-7, 3e-4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.save(pre_gen_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 128px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=20\n",
    "sz=128\n",
    "\n",
    "# Define target number of training/validation samples as well as number of epochs\n",
    "epoch_train_size = 100 * bs\n",
    "epoch_valid_size = 10 * bs\n",
    "valid_pct = epoch_valid_size / data_size\n",
    "number_epochs = (data_size - epoch_valid_size) // epoch_train_size\n",
    "\n",
    "# Log hyper parameters\n",
    "wandb.config.update({\"Step 2 - batch size\": bs, \"Step 2 - image size\": sz,\n",
    "                     \"Step 2 - epoch size\": epoch_train_size, \"Step 2 - number epochs\": number_epochs})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.data = get_data(bs=bs, sz=sz, random_seed=123, valid_pct=valid_pct, epoch_size=100*bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.unfreeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.fit_one_cycle(number_epochs, pct_start=pct_start, max_lr=slice(1e-7,1e-4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.save(pre_gen_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 192px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=8\n",
    "sz=192\n",
    "\n",
    "# Define target number of training/validation samples as well as number of epochs\n",
    "epoch_train_size = 100 * bs\n",
    "epoch_valid_size = 10 * bs\n",
    "valid_pct = epoch_valid_size / data_size\n",
    "number_epochs = (data_size - epoch_valid_size) // epoch_train_size // 2  # Training is long - we use half of data\n",
    "\n",
    "# Log hyper parameters\n",
    "wandb.config.update({\"Step 3 - batch size\": bs, \"Step 3 - image size\": sz,\n",
    "                     \"Step 3 - epoch size\": epoch_train_size, \"Step 3 - number epochs\": number_epochs})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.data = get_data(bs=bs, sz=sz, random_seed=123, valid_pct=valid_pct, epoch_size=100*bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.unfreeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.fit_one_cycle(number_epochs, pct_start=pct_start, max_lr=slice(5e-8,5e-5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.save(pre_gen_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# End logging of current session run\n",
    "# Note: this is optional and would be automatically triggered when stopping the kernel\n",
    "wandb.join()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Repeatable GAN Cycle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### NOTE\n",
    "Best results so far have been based on repeating the cycle below a few times (about 5-8?), until diminishing returns are hit (no improvement in image quality).  Each time you repeat the cycle, you want to increment that old_checkpoint_num by 1 so that new check points don't overwrite the old.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "old_checkpoint_num = 0\n",
    "checkpoint_num = old_checkpoint_num + 1\n",
    "gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
    "gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
    "crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
    "crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save Generated Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=8\n",
    "sz=192\n",
    "\n",
    "# Define target number of training/validation samples as well as number of epochs\n",
    "epoch_train_size = 100 * bs\n",
    "epoch_valid_size = 10 * bs\n",
    "valid_pct = epoch_valid_size / data_size\n",
    "number_epochs = (data_size - epoch_valid_size) // epoch_train_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_gen = get_data(bs=bs, sz=sz, random_seed=123, valid_pct=valid_pct, epoch_size=100*bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_gen_images(0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pretrain Critic"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Only need full pretraining of critic when starting from scratch.  Otherwise, just finetune!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if old_checkpoint_num == 0:\n",
    "    \n",
    "    # Init logging of a new run\n",
    "    wandb.init(tags=['Pre-train Crit'])  # tags are optional\n",
    "    \n",
    "    bs=64\n",
    "    sz=128\n",
    "    learn_gen=None\n",
    "    \n",
    "    # Log hyper parameters\n",
    "    wandb.config.update({\"Step 1 - batch size\": bs, \"Step 1 - image size\": sz})\n",
    "\n",
    "    gc.collect()    \n",
    "    data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\n",
    "    data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\n",
    "    learn_crit = colorize_crit_learner(data=data_crit, nf=256)\n",
    "    learn_crit.callback_fns.append(partial(WandbCallback))  # log prediction samples\n",
    "    learn_crit.fit_one_cycle(6, 1e-3)\n",
    "    learn_crit.save(crit_old_checkpoint_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=16\n",
    "sz=192\n",
    "\n",
    "# Log hyper parameters\n",
    "wandb.config.update({\"Step 2 - batch size\": bs, \"Step 2 - image size\": sz})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_crit.fit_one_cycle(4, 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_crit.save(crit_new_checkpoint_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# free up memory\n",
    "learn_crit=None\n",
    "learn_gen=None\n",
    "learn=None\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set old_checkpoint_num to last iteration\n",
    "old_checkpoint_num = 0\n",
    "save_checkpoints = False\n",
    "batch_per_epoch = 200\n",
    "\n",
    "checkpoint_num = old_checkpoint_num + 1\n",
    "gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
    "gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
    "crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
    "crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)   \n",
    "\n",
    "if False:   # need only to do it once\n",
    "        \n",
    "    # Generate data\n",
    "    print('Generating data…')\n",
    "    bs=8\n",
    "    sz=192\n",
    "    epoch_train_size = batch_per_epoch * bs\n",
    "    epoch_valid_size = batch_per_epoch * bs // 10\n",
    "    valid_pct = epoch_valid_size / data_size\n",
    "    data_gen = get_data(bs=bs, sz=sz, epoch_size=epoch_train_size, valid_pct=valid_pct)\n",
    "    learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)\n",
    "    save_gen_images(0.02)\n",
    "\n",
    "    # Pre-train critic\n",
    "    print('Pre-training critic…')\n",
    "    bs=16\n",
    "    sz=192\n",
    "\n",
    "    len_test = len(list((path / 'test').rglob('*.*')))\n",
    "    len_gen = len(list((path / name_gen).rglob('*.*')))\n",
    "    keep_test_pct = len_gen / len_test * 2\n",
    "\n",
    "    data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz, pct=keep_test_pct)\n",
    "    learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)\n",
    "    learn_crit.fit_one_cycle(1, 1e-4)\n",
    "    learn_crit.save(crit_new_checkpoint_name)\n",
    "\n",
    "# Creating GAN\n",
    "print('Creating GAN…')\n",
    "sz=192\n",
    "bs=8\n",
    "lr_GAN=2e-5\n",
    "epoch_train_size = batch_per_epoch * bs\n",
    "epoch_valid_size = batch_per_epoch * bs // 10\n",
    "valid_pct = epoch_valid_size / data_size\n",
    "len_test = len(list((path / 'test').rglob('*.*')))\n",
    "len_gen = len(list((path / name_gen).rglob('*.*')))\n",
    "keep_test_pct = len_gen / len_test * 2\n",
    "\n",
    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz, pct=keep_test_pct)\n",
    "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)\n",
    "data_gen = get_data(bs=bs, sz=sz, epoch_size=epoch_train_size, valid_pct=valid_pct)\n",
    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)\n",
    "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
    "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\n",
    "                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
    "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
    "learn.callback_fns.append(partial(WandbCallback, input_type='images', seed=None, save_model=False))\n",
    "learn.data = get_data(bs=bs, sz=sz, epoch_size=epoch_train_size, valid_pct=valid_pct)\n",
    "\n",
    "# Start logging to W&B\n",
    "wandb.init(tags=['GAN'])\n",
    "wandb.config.update({\"learning rate\": lr_GAN})  \n",
    "\n",
    "# Run the loop until satisfied with the results\n",
    "while True:\n",
    "\n",
    "    # Current loop\n",
    "    checkpoint_num = old_checkpoint_num + 1\n",
    "    gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
    "    gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
    "    crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
    "    crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)      \n",
    "    \n",
    "    \n",
    "    # GAN for 10 epochs between each checkpoint\n",
    "    try:\n",
    "        learn.fit(1, lr_GAN)\n",
    "    except:\n",
    "        # Sometimes we get an error for some unknown reason during callbacks\n",
    "        learn.callback_fns[-1](learn).on_epoch_end(old_checkpoint_num, None, [])\n",
    "        \n",
    "    if save_checkpoints:\n",
    "        learn_crit.save(crit_new_checkpoint_name)\n",
    "        learn_gen.save(gen_new_checkpoint_name)\n",
    "    \n",
    "    old_checkpoint_num += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# End logging of current session run\n",
    "# Note: this is optional and would be automatically triggered when stopping the kernel\n",
    "wandb.join()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
